{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# simple model——如何使用建模框架\n",
    "\n",
    "使用建模框架实现两个模型：\n",
    "\n",
    "- char分词attention模型\n",
    "- jieba分词attention模型\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# 1. char分词建模\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "## train\n",
    "### 参数设置——utils\n",
    "- 数据：cmn.txt\n",
    "- 切词：char\n",
    "- 句子：500\n",
    "\n",
    "### 参数设置——train\n",
    "\n",
    "- outputdir: logs/char\n",
    "- epochs: 50\n",
    "- unit type: lstm\n",
    "- units num: 512\n",
    "- num layer: 2\n",
    "- attention: luong\n",
    "- optimizer: sdg\n",
    "- learnrate: 1\n",
    "- keepprob : 0.8\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 包含数据处理函数\n",
    "from utils import GenData\n",
    "# 包含模型参数文件\n",
    "from params import create_hparams\n",
    "# 模型文件\n",
    "from model import BaseModel\n",
    "\n",
    "data = GenData('cmn.txt','char',500)\n",
    "param = create_hparams()\n",
    "param.out_dir = 'logs/char'\n",
    "param.encoder_vocab_size = len(data.id2en)\n",
    "param.decoder_vocab_size = len(data.id2ch)\n",
    "\n",
    "model = BaseModel(param, 'train')\n",
    "model.train(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "## infer\n",
    "### 参数设置——utils\n",
    "- 同train\n",
    "\n",
    "### 参数设置——infer\n",
    "\n",
    "- 解码方法：greedy（beam search有bug）\n",
    "\n",
    "set `param.batch_size = 1`\n",
    "\n",
    "set `model = BaseModel(param, 'infer')`\n",
    "\n",
    "use `model.inference(data)` make inference work"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "outputs": [],
   "source": [
    "from utils import GenData\n",
    "from params import create_hparams\n",
    "from model import BaseModel\n",
    "\n",
    "def main():\n",
    "    data = GenData('cmn.txt','char',500)\n",
    "    param = create_hparams()\n",
    "    param.out_dir = 'logs/char'\n",
    "    param.encoder_vocab_size = len(data.id2en)\n",
    "    param.decoder_vocab_size = len(data.id2ch)\n",
    "\n",
    "    # infer模式下需要改动\n",
    "    param.batch_size = 1\n",
    "    param.keepprob = 1\n",
    "\n",
    "    model = BaseModel(param, 'infer')\n",
    "    model.inference(data)\n",
    "\n",
    "main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# 2. jieba分词建模\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "## train\n",
    "### 参数设置——utils\n",
    "- 数据：cmn.txt\n",
    "- 切词：jieba\n",
    "- 句子：200\n",
    "\n",
    "### 参数设置——train\n",
    "\n",
    "- outputdir: logs/jieba\n",
    "- epochs: 50\n",
    "- unit type: lstm\n",
    "- units num: 512\n",
    "- num layer: 2\n",
    "- attention: luong\n",
    "- optimizer: sdg\n",
    "- learnrate: 1\n",
    "- keepprob : 0.8\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "outputs": [],
   "source": [
    "# 包含数据处理函数\n",
    "from utils import GenData\n",
    "# 包含模型参数文件\n",
    "from params import create_hparams\n",
    "# 模型文件\n",
    "from model import BaseModel\n",
    "\n",
    "data = GenData('cmn.txt','jieba',500)\n",
    "param = create_hparams()\n",
    "param.out_dir = 'logs/jieba'\n",
    "param.encoder_vocab_size = len(data.id2en)\n",
    "param.decoder_vocab_size = len(data.id2ch)\n",
    "\n",
    "model = BaseModel(param, 'train')\n",
    "model.train(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### 参数设置——utils\n",
    "- 同train\n",
    "\n",
    "### 参数设置——infer\n",
    "\n",
    "- 解码方法：greedy（beam search有bug）\n",
    "\n",
    "set `param.batch_size = 1`\n",
    "\n",
    "set `model = BaseModel(param, 'infer')`\n",
    "\n",
    "use `model.inference(data)` make inference work"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "outputs": [],
   "source": [
    "from utils import GenData\n",
    "from params import create_hparams\n",
    "from model import BaseModel\n",
    "\n",
    "def main():\n",
    "    data = GenData('cmn.txt','jieba',500)\n",
    "    param = create_hparams()\n",
    "    param.out_dir = 'logs/jieba'\n",
    "    param.encoder_vocab_size = len(data.id2en)\n",
    "    param.decoder_vocab_size = len(data.id2ch)\n",
    "\n",
    "    # infer模式下需要改动\n",
    "    param.batch_size = 1\n",
    "    param.keepprob = 1\n",
    "\n",
    "    model = BaseModel(param, 'infer')\n",
    "    model.inference(data)\n",
    "\n",
    "main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# 3. hanlp分词建模\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## train\n",
    "### 参数设置——utils\n",
    "- 数据：cmn.txt\n",
    "- 切词：hanlp\n",
    "- 句子：200\n",
    "\n",
    "### 参数设置——train\n",
    "\n",
    "- outputdir: logs/hanlp\n",
    "- epochs: 50\n",
    "- unit type: lstm\n",
    "- units num: 512\n",
    "- num layer: 2\n",
    "- attention: luong\n",
    "- optimizer: sdg\n",
    "- learnrate: 1\n",
    "- keepprob : 0.8\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "outputs": [],
   "source": [
    "# 包含数据处理函数\n",
    "from utils import GenData\n",
    "# 包含模型参数文件\n",
    "from params import create_hparams\n",
    "# 模型文件\n",
    "from model import BaseModel\n",
    "\n",
    "data = GenData('cmn.txt','hanlp',500)\n",
    "param = create_hparams()\n",
    "param.out_dir = 'logs/hanlp'\n",
    "param.encoder_vocab_size = len(data.id2en)\n",
    "param.decoder_vocab_size = len(data.id2ch)\n",
    "\n",
    "model = BaseModel(param, 'train')\n",
    "model.train(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### 参数设置——utils\n",
    "- 同train\n",
    "\n",
    "### 参数设置——infer\n",
    "\n",
    "- 解码方法：greedy（beam search有bug）\n",
    "\n",
    "set `param.batch_size = 1`\n",
    "\n",
    "set `model = BaseModel(param, 'infer')`\n",
    "\n",
    "use `model.inference(data)` make inference work"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "outputs": [],
   "source": [
    "from utils import GenData\n",
    "from params import create_hparams\n",
    "from model import BaseModel\n",
    "\n",
    "def main():\n",
    "    data = GenData('cmn.txt','hanlp',500)\n",
    "    param = create_hparams()\n",
    "    param.out_dir = 'logs/hanlp'\n",
    "    param.encoder_vocab_size = len(data.id2en)\n",
    "    param.decoder_vocab_size = len(data.id2ch)\n",
    "\n",
    "    # infer模式下需要改动\n",
    "    param.batch_size = 1\n",
    "    param.keepprob = 1\n",
    "\n",
    "    model = BaseModel(param, 'infer')\n",
    "    model.inference(data)\n",
    "\n",
    "main()"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
